import xml.etree.ElementTree as ET
from collections.abc import Sequence
from typing import Generic, Optional, cast

from langchain_core.agents import AgentAction
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import (
    AIMessage as LangchainAssistantMessage,
    BaseMessage,
    merge_message_runs,
)
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.runnables import RunnableConfig

from posthog.schema import ArtifactContentType, ArtifactSource, VisualizationArtifactContent

from posthog.models.group_type_mapping import GroupTypeMapping

from ee.hogai.core.node import AssistantNode
from ee.hogai.llm import MaxChatOpenAI
from ee.hogai.utils.feature_flags import has_agent_modes_feature_flag
from ee.hogai.utils.helpers import find_start_message
from ee.hogai.utils.types import AssistantState, IntermediateStep, PartialAssistantState
from ee.hogai.utils.types.base import ArtifactRefMessage

from .parsers import PydanticOutputParserException, parse_pydantic_structured_output
from .prompts import (
    FAILOVER_OUTPUT_PROMPT,
    FAILOVER_PROMPT,
    GROUP_MAPPING_PROMPT,
    NEW_PLAN_PROMPT,
    PLAN_PROMPT,
    QUESTION_PROMPT,
)
from .utils import Q, SchemaGeneratorOutput

RETRIES_ALLOWED = 2


class SchemaGenerationException(Exception):
    """Raised when the `SchemaGeneratorNode` fails to produce a usable schema.

    Carries the raw LLM output and the validation message so callers can
    inspect or surface what went wrong.
    """

    # Raw text the LLM produced for the failed generation attempt.
    llm_output: str
    # Human-readable description of why the output was rejected.
    validation_message: str

    def __init__(self, llm_output: str, validation_message: str):
        self.llm_output = llm_output
        self.validation_message = validation_message
        super().__init__("Failed to generate schema")


class SchemaGeneratorNode(AssistantNode, Generic[Q]):
    """Base node that prompts an LLM to generate an insight query schema.

    Subclasses configure `INSIGHT_NAME`, `OUTPUT_MODEL`, and `OUTPUT_SCHEMA`, then call
    `_run_with_prompt` with their system prompt. Parse/validation failures are retried up to
    `RETRIES_ALLOWED` times by feeding the error back to the LLM via `intermediate_steps`.
    """

    INSIGHT_NAME: str
    """
    Name of the insight type used in the exception messages.
    """
    OUTPUT_MODEL: type[SchemaGeneratorOutput[Q]]
    """Pydantic model of the output to be generated by the LLM."""
    OUTPUT_SCHEMA: dict
    """JSON schema of OUTPUT_MODEL for LLM's use."""

    @property
    def _model(self):
        """Structured-output LLM enforcing `OUTPUT_SCHEMA` via the Responses API."""
        return MaxChatOpenAI(
            model="gpt-5.1",
            temperature=0.3,
            disable_streaming=True,
            user=self._user,
            team=self._team,
            max_tokens=8192,
            billable=True,
            output_version="responses/v1",
            use_responses_api=True,
            reasoning={
                "effort": "none",
            },
            model_kwargs={
                # Stable per-team key so OpenAI prompt caching can reuse the shared prefix.
                "prompt_cache_key": f"team_{self._team.id}",
            },
        ).with_structured_output(
            self.OUTPUT_SCHEMA,
            method="json_schema",
            include_raw=False,
        )

    def _parse_output(self, output: dict) -> SchemaGeneratorOutput[Q]:
        """This can raise a PydanticOutputParserException if the output is not parsable (therefore unusable)."""
        return parse_pydantic_structured_output(self.OUTPUT_MODEL)(output)

    async def _quality_check_output(self, output: SchemaGeneratorOutput[Q]) -> None:
        """
        If implemented, this can raise a PydanticOutputParserException exception if something's off about the output
        (e.g. a non-existent table field is used).

        Raising here means that the LLM should iterate on the output, but also that it's still usable
        if we aren't able to resolve the issue in a couple attempts.
        """
        pass

    async def _run_with_prompt(
        self,
        state: AssistantState,
        prompt: ChatPromptTemplate,
        config: Optional[RunnableConfig] = None,
    ) -> PartialAssistantState:
        """Generate a schema with `prompt` prepended to the reconstructed conversation.

        Returns a partial state containing an artifact message on success, or a retry state
        with an appended intermediate step when the output was rejected and attempts remain.

        Raises:
            SchemaGenerationException: when the LLM output could not even be parsed after
                all retries (a parsed-but-quality-check-failing output is used as-is).
        """
        generated_plan = state.plan or ""
        intermediate_steps: Sequence[IntermediateStep] = state.intermediate_steps or []
        # The observation of the latest step carries the failover feedback prepared by
        # SchemaGeneratorToolsNode for the previous failed attempt.
        validation_error_message = intermediate_steps[-1][1] if intermediate_steps else None

        message_history = await self._construct_messages(state, validation_error_message=validation_error_message)
        generation_prompt = prompt + message_history
        merger = merge_message_runs()

        chain = generation_prompt | merger | self._model | self._parse_output

        # `result` stays None only if parsing itself failed; a quality-check failure
        # happens after a successful parse, so `result` is then still usable.
        result: Optional[SchemaGeneratorOutput[Q]] = None
        try:
            result = cast(
                SchemaGeneratorOutput[Q],
                await chain.ainvoke(
                    {
                        "project_datetime": self.project_now,
                        "project_timezone": self.project_timezone,
                        "project_name": self._team.name,
                    },
                    config,
                ),
            )
            # If quality check raises, we will still iterate if we've got any attempts left,
            # however if we don't have any more attempts, we're okay to use `result` (instead of throwing)
            await self._quality_check_output(result)
        except (PydanticOutputParserException, OutputParserException) as e:
            # Try again with feedback a couple times
            if len(intermediate_steps) < RETRIES_ALLOWED:
                return PartialAssistantState(
                    intermediate_steps=[
                        *intermediate_steps,
                        (
                            AgentAction(
                                "handle_incorrect_response",
                                e.llm_output or "No input was provided.",
                                e.validation_message
                                if isinstance(e, PydanticOutputParserException)
                                else "The provided JSON was invalid.",
                            ),
                            None,  # Observation is filled in by SchemaGeneratorToolsNode.
                        ),
                    ],
                    query_generation_retry_count=len(intermediate_steps) + 1,
                )

            if result is None:
                # Parsing failed outright, so the output is unusable — surface the failure.
                if isinstance(e, PydanticOutputParserException):
                    raise SchemaGenerationException(e.llm_output, e.validation_message) from e
                raise SchemaGenerationException(e.llm_output or "No input was provided.", str(e)) from e
            # The output parsed but failed the quality check and we're out of retries:
            # per the contract above, proceed with the parseable `result` as-is.

        # We've got a result that either passed the quality check or we've exhausted all attempts at iterating - return
        # Create an artifact with the visualization content
        artifact = await self.context_manager.artifacts.create(
            content=VisualizationArtifactContent(
                query=result.query,
                description=generated_plan or None,
            ),
            name=state.visualization_title or "Visualization",
        )
        artifact_message = self.context_manager.artifacts.create_message(
            artifact_id=artifact.short_id,
            source=ArtifactSource.ARTIFACT,
            content_type=ArtifactContentType.VISUALIZATION,
        )

        # Reset generation scratch state now that the artifact is produced.
        return PartialAssistantState(
            messages=[artifact_message],
            intermediate_steps=None,
            plan=None,
            rag_context=None,
            query_generation_retry_count=len(intermediate_steps),
        )

    def router(self, state: AssistantState):
        """Route to the failover tools node while a retry is pending, else continue."""
        if state.intermediate_steps:
            return "tools"
        return "next"

    async def _get_group_mapping_prompt(self) -> str:
        """Render the project's group type mappings for inclusion in the prompt."""
        groups = GroupTypeMapping.objects.filter(project_id=self._team.project_id).order_by("group_type_index")
        group_names = [f'name "{group.group_type}", index {group.group_type_index}' async for group in groups]
        if not group_names:
            return "The user has not defined any groups."

        # NOTE: the tag deliberately contains spaces — this yields pseudo-XML meant for
        # LLM consumption, not a document a strict XML parser would accept.
        root = ET.Element("list of defined groups")
        root.text = "\n" + "\n".join(group_names) + "\n"
        return ET.tostring(root, encoding="unicode")

    async def _construct_messages(
        self, state: AssistantState, validation_error_message: str | None = None
    ) -> list[BaseMessage]:
        """
        Reconstruct the conversation for the generation. Take all previously generated questions, plans, and schemas, and return the history.
        """
        # Only process the last five artifact messages.
        artifact_messages = await self.context_manager.artifacts.aenrich_messages(
            [message for message in state.messages if isinstance(message, ArtifactRefMessage)][-5:]
        )
        generated_plan = state.plan

        # Add the group mapping prompt to the beginning of the conversation.
        group_mapping = await self._get_group_mapping_prompt()
        conversation: list[BaseMessage] = [
            HumanMessagePromptTemplate.from_template(GROUP_MAPPING_PROMPT, template_format="mustache").format(
                group_mapping=group_mapping
            )
        ]

        # Batch fetch all artifact contents (pass full state.messages for State source lookup)
        artifact_contents = await self.context_manager.artifacts.aget_contents_by_message_id(state.messages)

        for message in artifact_messages:
            content = artifact_contents.get(message.id or "")
            if not content:
                # Skip artifact refs whose content could not be resolved.
                continue
            plan = content.description or ""
            query = content.name or ""
            answer = content.query

            # Plans go first.
            conversation.append(
                HumanMessagePromptTemplate.from_template(PLAN_PROMPT, template_format="mustache").format(plan=plan)
            )
            # Then questions.
            conversation.append(
                HumanMessagePromptTemplate.from_template(QUESTION_PROMPT, template_format="mustache").format(
                    question=query
                )
            )
            # Then the answer.
            if answer:
                conversation.append(LangchainAssistantMessage(content=answer.model_dump_json()))

        # Add the initiator message and the generated plan to the end, so instructions are clear.
        if generated_plan:
            prompt = NEW_PLAN_PROMPT if artifact_messages else PLAN_PROMPT
            conversation.append(
                HumanMessagePromptTemplate.from_template(prompt, template_format="mustache").format(
                    plan=generated_plan or ""
                )
            )
        conversation.append(
            HumanMessagePromptTemplate.from_template(QUESTION_PROMPT, template_format="mustache").format(
                question=self._get_insight_plan(state)
            )
        )

        # Retries must be added to the end of the conversation.
        if validation_error_message:
            conversation.append(
                HumanMessagePromptTemplate.from_template(FAILOVER_PROMPT, template_format="mustache").format(
                    validation_error_message=validation_error_message
                )
            )

        return conversation

    def _get_insight_plan(self, state: AssistantState) -> str:
        """Return the plan to generate from: the root tool's plan, else the start message, else ''."""
        if state.root_tool_insight_plan:
            return state.root_tool_insight_plan
        start_message = find_start_message(state.messages, state.start_id)
        if start_message:
            return start_message.content
        return ""

    def _has_agent_modes_feature_flag(self) -> bool:
        """Whether the agent-modes feature flag is enabled for this team/user."""
        return has_agent_modes_feature_flag(self._team, self._user)


class SchemaGeneratorToolsNode(AssistantNode):
    """
    Used for failover from generation errors.
    """

    async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
        """Attach a failover observation to the most recent intermediate step.

        Renders the failover prompt from the last step's faulty output and validation log,
        and stores the rendered text as that step's observation. Returns None when there is
        no pending step to annotate.
        """
        steps = state.intermediate_steps or []
        if not steps:
            return None

        last_action, _ = steps[-1]
        template = ChatPromptTemplate.from_template(FAILOVER_OUTPUT_PROMPT, template_format="mustache")
        rendered = template.format_messages(output=last_action.tool_input, exception_message=last_action.log)
        observation = str(rendered[0].content)

        # Replace the last step, keeping the action but filling in the observation.
        return PartialAssistantState(intermediate_steps=[*steps[:-1], (last_action, observation)])
